package dao

import (
	"errors"
	"gorm.io/gorm"
	"product-service/common"
	"product-service/dao/entity"
)

// PmsCategoryDao is a shared, package-level instance of pmsCategoryDao.
var PmsCategoryDao = &pmsCategoryDao{}

// pmsCategoryDao groups the data-access methods for entity.PmsCategory.
// It carries no fields; the methods in this file receive the *gorm.DB handle
// to operate on as an argument.
type pmsCategoryDao struct{}

//func (this *pmsCategoryDao) FindCategoryArray2(id int64) (p []int64) {
//	p = make([]int64, 3)
//	p = append(p, id)
//	id1, _ := this.FindById(id)
//	p = append(p, id1.ParentCid)
//	id2, _ := this.FindById(id1.CatId)
//	p = append(p, id2.CatId)
//	return p
//}

// FindCategoryArray resolves the ancestor chain of a category and returns it
// as a three-element slice ordered from the top-most ancestor down to id:
// [grandparent cat_id, parent cat_id, id].
//
// The query self-joins pms_category twice and filters on the deepest alias,
// so only an id that has both a parent row and a grandparent row selects a
// row. In every other case the error produced by scanning the row is returned
// together with a nil slice.
func (d *pmsCategoryDao) FindCategoryArray(db *gorm.DB, id int64) ([]int64, error) {
	query := `SELECT
				p1.cat_id p1,
				p2.cat_id p2,
				p3.cat_id p3
			FROM
				pms_category p1
				LEFT JOIN pms_category p2 ON p1.cat_id = p2.parent_cid
			LEFT JOIN pms_category p3 ON p2.cat_id=p3.parent_cid
			WHERE p3.cat_id=?
		`
	// One destination per selected column, in SELECT order.
	var top, middle, leaf int64
	if err := db.Raw(query, id).Row().Scan(&top, &middle, &leaf); err != nil {
		return nil, err
	}
	return []int64{top, middle, leaf}, nil
}

// FindById loads the category whose primary key equals id.
//
// It returns the gorm error when the query itself fails, and a dedicated
// "not found" error when the query succeeds but selects no row.
func (d *pmsCategoryDao) FindById(db *gorm.DB, id int64) (*entity.PmsCategory, error) {
	model := &entity.PmsCategory{}
	// Surface database failures instead of misreporting them as "not found".
	if err := db.Find(model, id).Error; err != nil {
		return nil, err
	}
	// Find leaves the destination untouched when nothing matches, so a zero
	// primary key means no row was selected.
	if model.CatId == 0 {
		// NOTE(review): the message reads "no such user" although this DAO
		// handles categories; kept byte-identical in case callers or clients
		// match on it - confirm before rewording.
		return nil, errors.New("无此用户")
	}
	return model, nil
}

// FindAll returns one page of categories matching the filter carried in
// query.Model, together with the total number of rows matching that filter.
//
// query.Model may be nil (no filter) or a *entity.PmsCategory, which is handed
// to gorm's Where as a struct condition. Any other type yields an error
// instead of the panic an unchecked type assertion would raise.
//
// The total is counted with the same filter as the page query but without the
// Page scope, and a failing count query is reported rather than ignored.
func (d *pmsCategoryDao) FindAll(db *gorm.DB, query *common.PageQuery) ([]*entity.PmsCategory, int64, error) {
	var filter *entity.PmsCategory
	if query.Model != nil {
		typed, ok := query.Model.(*entity.PmsCategory)
		if !ok {
			return nil, 0, errors.New("dao: FindAll expects query.Model to be *entity.PmsCategory")
		}
		filter = typed
	}

	base := db.Model(&entity.PmsCategory{})
	if filter != nil {
		base = base.Where(filter)
	}
	// Mark the chain as shareable so Count and Find each build their own
	// statement instead of mutating a common one.
	base = base.Session(&gorm.Session{})

	// Count before the Page scope is attached so the total is not affected by
	// whatever paging clauses Page adds.
	var count int64
	if err := base.Count(&count).Error; err != nil {
		return nil, 0, err
	}

	var models []*entity.PmsCategory
	paged := base.Scopes(Page(int(query.Page), int(query.Size)))
	if err := paged.Find(&models).Error; err != nil {
		return nil, 0, err
	}
	return models, count, nil
}

// DeleteById deletes the category whose primary key equals id.
//
// Any error reported by gorm is replaced with a fixed "invalid delete" error;
// nil is returned otherwise.
func (d *pmsCategoryDao) DeleteById(db *gorm.DB, id int64) error {
	if db.Delete(&entity.PmsCategory{}, id).Error != nil {
		return errors.New("无效删除")
	}
	return nil
}

// UpdateById writes model to the database through gorm's Updates.
//
// model.CatId is overwritten with id before the call, so the caller's struct
// is modified. Any error reported by gorm is replaced with a fixed
// "invalid update" error; nil is returned otherwise.
func (d *pmsCategoryDao) UpdateById(db *gorm.DB, id int64, model *entity.PmsCategory) error {
	model.CatId = id
	if db.Updates(model).Error != nil {
		return errors.New("无效更新")
	}
	return nil
}

// Save inserts model through gorm's Create and returns the same pointer.
//
// Success is judged solely by the number of affected rows: when it is zero a
// fixed "save failed" error is returned together with a nil model.
func (d *pmsCategoryDao) Save(db *gorm.DB, model *entity.PmsCategory) (*entity.PmsCategory, error) {
	if inserted := db.Create(model).RowsAffected; inserted == 0 {
		return nil, errors.New("保存失败")
	}
	return model, nil
}

// FindTree loads every category row and returns the result of arranging the
// rows with recursionModels.
//
// A query error is returned as-is together with a nil slice.
func (d *pmsCategoryDao) FindTree(db *gorm.DB) ([]*entity.PmsCategory, error) {
	var all []*entity.PmsCategory
	if err := db.Find(&all).Error; err != nil {
		return nil, err
	}
	// Nest the flat rows under their parents.
	return recursionModels(all), nil
}

// UpdateBatch applies Updates to every entry of models inside a single
// transaction.
//
// The first failing update aborts the batch: the transaction is rolled back
// and that error is returned. A failure to begin or to commit the transaction
// is returned as well, so a nil result means the whole batch was committed.
// An empty batch is a no-op.
func (d *pmsCategoryDao) UpdateBatch(db *gorm.DB, models []*entity.PmsCategory) error {
	if len(models) == 0 {
		return nil
	}
	// Transaction commits when the callback returns nil and rolls back when
	// it returns an error or panics.
	return db.Transaction(func(tx *gorm.DB) error {
		for _, model := range models {
			// Per the original author's note, the row to update is
			// identified by the primary key carried in the model itself.
			if err := tx.Updates(model).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

// DeleteBatch deletes the categories whose primary keys are listed in ids,
// inside a single transaction.
//
// An empty ids slice is a no-op that returns nil: it must never reach the
// database as a delete without a condition. Errors from beginning the
// transaction, from the delete itself, or from the commit are returned.
func (d *pmsCategoryDao) DeleteBatch(db *gorm.DB, ids []int64) error {
	if len(ids) == 0 {
		return nil
	}
	// Transaction commits when the callback returns nil and rolls back when
	// it returns an error or panics.
	return db.Transaction(func(tx *gorm.DB) error {
		return tx.Delete(&entity.PmsCategory{}, ids).Error
	})
}

// recursionModels arranges a flat list of categories into a forest.
//
// Entries whose ParentCid is 0 become the roots and are returned in input
// order. Every other entry is appended to the Children of the entry whose
// CatId equals its ParentCid, also in input order, at any depth. Entries
// whose parent is not present in models are left out of the result.
//
// The entries are modified in place (their Children slices are appended to),
// so passing the same slice twice duplicates the children.
func recursionModels(models []*entity.PmsCategory) []*entity.PmsCategory {
	// Index every entry by id once so each parent lookup is a map access
	// rather than another scan of the whole list.
	byID := make(map[int64]*entity.PmsCategory, len(models))
	for _, m := range models {
		byID[m.CatId] = m
	}

	var roots []*entity.PmsCategory
	for _, m := range models {
		if m.ParentCid == 0 {
			roots = append(roots, m)
			continue
		}
		if parent, ok := byID[m.ParentCid]; ok {
			parent.Children = append(parent.Children, m)
		}
	}
	return roots
}
